"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    """Class-level store of named components, shared state and paths.

    All data lives in the ``mapping`` class attribute below, so every user
    of the class (or of any instance of it) sees the same registrations.
    """

    # One table per kind of registered object, keyed by registration name.
    mapping = {
        'builder_name_mapping': {},
        'task_name_mapping': {},
        'processor_name_mapping': {},
        'model_name_mapping': {},
        'lr_scheduler_name_mapping': {},
        'runner_name_mapping': {},
        # Arbitrary values stored by register(); dotted names become nested dicts.
        'state': {},
        # name -> path string, filled by register_path().
        'paths': {},
    }

    # @classmethod
    # def register_builder(cls, name):
    #     r"""Register a dataset builder to registry with key 'name'

    #     Args:
    #         name: Key with which the builder will be registered.

    #     Usage:

    #         from lavis.common.registry import registry
    #         from lavis.datasets.base_dataset_builder import BaseDatasetBuilder
    #     """

    #     def wrap(builder_cls):
    #         from lavis.datasets.builders.base_dataset_builder import BaseDatasetBuilder

    #         assert issubclass(
    #             builder_cls, BaseDatasetBuilder
    #         ), "All builders must inherit BaseDatasetBuilder class, found {}".format(
    #             builder_cls
    #         )
    #         if name in cls.mapping["builder_name_mapping"]:
    #             raise KeyError(
    #                 "Name '{}' already registered for {}.".format(
    #                     name, cls.mapping["builder_name_mapping"][name]
    #                 )
    #             )
    #         cls.mapping["builder_name_mapping"][name] = builder_cls
    #         return builder_cls

    #     return wrap

    # @classmethod
    # def register_task(cls, name):
    #     r"""Register a task to registry with key 'name'

    #     Args:
    #         name: Key with which the task will be registered.

    #     Usage:

    #         from lavis.common.registry import registry
    #     """

    #     def wrap(task_cls):
    #         from lavis.tasks.base_task import BaseTask

    #         assert issubclass(
    #             task_cls, BaseTask
    #         ), "All tasks must inherit BaseTask class"
    #         if name in cls.mapping["task_name_mapping"]:
    #             raise KeyError(
    #                 "Name '{}' already registered for {}.".format(
    #                     name, cls.mapping["task_name_mapping"][name]
    #                 )
    #             )
    #         cls.mapping["task_name_mapping"][name] = task_cls
    #         return task_cls

    #     return wrap

    @classmethod
    def register_model(cls, name):
        r"""Return a class decorator that registers a model with key 'name'

        Args:
            name: Key with which the model will be registered.

        Returns:
            A decorator that stores the decorated class in
            ``mapping['model_name_mapping']`` and returns it unchanged.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseModel``.
            KeyError: ``name`` is already registered for a model.

        Usage::

            @registry.register_model("my_model")
            class MyModel(BaseModel):
                ...
        """

        def wrap(model_cls):
            # Imported inside the decorator, so it only runs when a class is decorated.
            from ..models import BaseModel

            assert issubclass(model_cls, BaseModel), 'All models must inherit BaseModel class'

            known_models = cls.mapping['model_name_mapping']
            if name in known_models:
                message = "Name '{}' already registered for {}.".format(name, known_models[name])
                raise KeyError(message)

            known_models[name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Return a class decorator that registers a processor with key 'name'

        Args:
            name: Key with which the processor will be registered.

        Returns:
            A decorator that stores the decorated class in
            ``mapping['processor_name_mapping']`` and returns it unchanged.

        Raises (when the decorator is applied):
            AssertionError: the class does not inherit ``BaseProcessor``.
            KeyError: ``name`` is already registered for a processor.

        Usage::

            @registry.register_processor("my_processor")
            class MyProcessor(BaseProcessor):
                ...
        """

        def wrap(processor_cls):
            # Imported inside the decorator, so it only runs when a class is decorated.
            from ..processors import BaseProcessor

            assert issubclass(processor_cls, BaseProcessor), 'All processors must inherit BaseProcessor class'

            known_processors = cls.mapping['processor_name_mapping']
            if name in known_processors:
                message = "Name '{}' already registered for {}.".format(name, known_processors[name])
                raise KeyError(message)

            known_processors[name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_lr_scheduler(cls, name):
        r"""Return a class decorator that registers an lr scheduler with key 'name'

        No base-class check is performed on the decorated class.

        Args:
            name: Key with which the lr scheduler will be registered.

        Returns:
            A decorator that stores the decorated class in
            ``mapping['lr_scheduler_name_mapping']`` and returns it unchanged.

        Raises (when the decorator is applied):
            KeyError: ``name`` is already registered for an lr scheduler.

        Usage::

            @registry.register_lr_scheduler("my_scheduler")
            class MyScheduler:
                ...
        """

        def wrap(lr_sched_cls):
            schedulers = cls.mapping['lr_scheduler_name_mapping']
            if name in schedulers:
                message = "Name '{}' already registered for {}.".format(name, schedulers[name])
                raise KeyError(message)

            schedulers[name] = lr_sched_cls
            return lr_sched_cls

        return wrap

    @classmethod
    def register_runner(cls, name):
        r"""Return a class decorator that registers a runner with key 'name'

        No base-class check is performed on the decorated class.

        Args:
            name: Key with which the runner will be registered.

        Returns:
            A decorator that stores the decorated class in
            ``mapping['runner_name_mapping']`` and returns it unchanged.

        Raises (when the decorator is applied):
            KeyError: ``name`` is already registered for a runner.

        Usage::

            @registry.register_runner("my_runner")
            class MyRunner:
                ...
        """

        def wrap(runner_cls):
            runners = cls.mapping['runner_name_mapping']
            if name in runners:
                message = "Name '{}' already registered for {}.".format(name, runners[name])
                raise KeyError(message)

            runners[name] = runner_cls
            return runner_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.
            path: The path to store; must be a ``str``.

        Raises:
            AssertionError: ``path`` is not a ``str``.
            KeyError: ``name`` already has a registered path.

        Usage::

            registry.register_path("cache_root", "/tmp/cache")
        """
        assert isinstance(path, str), 'All path must be str.'

        known_paths = cls.mapping['paths']
        if name in known_paths:
            raise KeyError("Name '{}' already registered.".format(name))

        known_paths[name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        The name is split on ``'.'``; every part except the last selects (and
        creates, when absent) a nested dict inside ``mapping['state']``, and
        the last part is the key under which ``obj`` is stored. An existing
        value at that key is overwritten.

        Args:
            name: Dotted key with which the item will be registered.
            obj: The item to store.

        Usage::

            registry.register("config", {})
            registry.register("run.output_dir", "/tmp/out")
        """
        *parent_keys, leaf_key = name.split('.')

        node = cls.mapping['state']
        for key in parent_keys:
            if key not in node:
                node[key] = {}
            node = node[key]

        node[leaf_key] = obj

    # @classmethod
    # def get_trainer_class(cls, name):
    #     return cls.mapping["trainer_name_mapping"].get(name, None)

    @classmethod
    def get_builder_class(cls, name):
        """Return the builder class stored under ``name``, or ``None`` if absent."""
        builders = cls.mapping['builder_name_mapping']
        return builders.get(name)

    @classmethod
    def get_model_class(cls, name):
        """Return the model class registered under ``name``, or ``None`` if absent."""
        models = cls.mapping['model_name_mapping']
        return models.get(name)

    @classmethod
    def get_task_class(cls, name):
        """Return the task class stored under ``name``, or ``None`` if absent."""
        tasks = cls.mapping['task_name_mapping']
        return tasks.get(name)

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor class registered under ``name``, or ``None`` if absent."""
        processors = cls.mapping['processor_name_mapping']
        return processors.get(name)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        """Return the lr scheduler class registered under ``name``, or ``None`` if absent."""
        schedulers = cls.mapping['lr_scheduler_name_mapping']
        return schedulers.get(name)

    @classmethod
    def get_runner_class(cls, name):
        """Return the runner class registered under ``name``, or ``None`` if absent."""
        runners = cls.mapping['runner_name_mapping']
        return runners.get(name)

    @classmethod
    def list_runners(cls):
        """Return the names of all registered runners as a sorted list."""
        runner_names = cls.mapping['runner_name_mapping']
        return sorted(runner_names)

    @classmethod
    def list_models(cls):
        """Return the names of all registered models as a sorted list."""
        model_names = cls.mapping['model_name_mapping']
        return sorted(model_names)

    @classmethod
    def list_tasks(cls):
        """Return the names stored in the task table as a sorted list."""
        task_names = cls.mapping['task_name_mapping']
        return sorted(task_names)

    @classmethod
    def list_processors(cls):
        """Return the names of all registered processors as a sorted list."""
        processor_names = cls.mapping['processor_name_mapping']
        return sorted(processor_names)

    @classmethod
    def list_lr_schedulers(cls):
        """Return the names of all registered lr schedulers as a sorted list."""
        scheduler_names = cls.mapping['lr_scheduler_name_mapping']
        return sorted(scheduler_names)

    @classmethod
    def list_datasets(cls):
        """Return the names stored in the builder table as a sorted list."""
        builder_names = cls.mapping['builder_name_mapping']
        return sorted(builder_names)

    @classmethod
    def get_path(cls, name):
        """Return the path registered under ``name``, or ``None`` if absent."""
        known_paths = cls.mapping['paths']
        return known_paths.get(name)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        The name is split on ``'.'`` and each part is looked up in turn,
        starting from ``mapping['state']``.

        Args:
            name (string): Dotted key whose value needs to be retrieved.
            default: Value returned when the key is not in the registry.
                     Default: None
            no_warning (bool): If passed as True, no warning is emitted when
                               the key doesn't exist. Default: False

        Returns:
            The stored value, or ``default`` when some part of the dotted
            path is missing or an intermediate value has no ``get`` method
            to descend into.
        """
        # A private sentinel separates "key absent" from a stored value that
        # merely equals ``default`` (e.g. a registered ``None``), and avoids
        # comparing arbitrary stored objects with ``==``.
        missing = object()

        value = cls.mapping['state']
        for subname in name.split('.'):
            lookup = getattr(value, 'get', None)
            # A leaf value without ``get`` cannot contain the rest of the path.
            value = lookup(subname, missing) if callable(lookup) else missing
            if value is missing:
                break

        if value is not missing:
            return value

        # The warning goes to whatever object was registered under 'writer'.
        if 'writer' in cls.mapping['state'] and no_warning is False:
            cls.mapping['state']['writer'].warning(
                'Key {} is not present in registry, returning default value '
                'of {}'.format(name, default)
            )
        return default

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        A key that exists literally at the top level of ``mapping['state']``
        is removed first. Otherwise the name is split on ``'.'`` and the
        nested dicts are walked the same way ``register`` builds them, so an
        item stored with a dotted name can be removed with that same name.
        Parent dicts are left in place even if they become empty.

        Args:
            name: Key which needs to be removed.

        Returns:
            The removed value, or ``None`` when nothing was stored under
            ``name``.

        Usage::

            config = registry.unregister("config")
        """
        state = cls.mapping['state']

        # Literal top-level keys (and non-string keys, which cannot be dotted
        # paths) keep the plain ``pop`` behaviour.
        if not isinstance(name, str) or name in state:
            return state.pop(name, None)

        *parent_keys, leaf_key = name.split('.')
        node = state
        for key in parent_keys:
            if not isinstance(node, dict) or key not in node:
                return None
            node = node[key]

        if not isinstance(node, dict):
            return None
        return node.pop(leaf_key, None)


# Module-level instance for importing; the registered data itself lives on the
# Registry class (``Registry.mapping``), so it is shared with the class.
registry = Registry()
